#include <torch/extension.h>

#include <vector>

// Input-validation helpers built on TORCH_CHECK. The failure message starts with the
// stringified macro argument, so pass the variable itself rather than an expression.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Runs both checks. Wrapped in do { ... } while (0) so the expansion is a single statement:
// in `if (cond) CHECK_INPUT(t);` both checks are then controlled by `cond`, instead of only
// the first one.
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)

// Declaration only: the definition is not in this file (presumably it lives in the CUDA
// source of this extension -- confirm against the build sources).
// Called by transducer_loss_forward once the tensor arguments have been validated.
// Note the naming difference: the parameters called audLen/txtLen here receive the
// wrapper's fLen/yLen arguments, in that order.
// Returns a vector of tensors; its contents are determined by the definition.
std::vector<torch::Tensor> transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen,
                                                        torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen,
                                                        int blankIdx, int opt, bool packedInput);

// Declaration only: the definition is not in this file (presumably it lives in the CUDA
// source of this extension -- confirm against the build sources).
// Called by transducer_loss_backward once the tensor arguments have been validated.
// Note the naming difference: the parameters called audLen/txtLen here receive the
// wrapper's fLen/yLen arguments, in that order.
// Returns a single tensor; its meaning is determined by the definition.
torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha,
                                            torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen,
                                            torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx,
                                            int opt, bool fuseSoftmaxBackward, bool packedInput);

// Checked entry point for the transducer loss forward pass.
//
// Verifies that x, label, fLen and yLen are contiguous CUDA tensors, and that batchOffset
// is one as well when packedInput is true. All arguments are then passed through unchanged
// to transducer_loss_cuda_forward and its result is returned as-is.
// A failed check is reported through TORCH_CHECK.
std::vector<torch::Tensor> transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen,
                                                   torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen,
                                                   int blankIdx, int opt, bool packedInput) {
  CHECK_INPUT(x);
  CHECK_INPUT(label);
  CHECK_INPUT(fLen);
  CHECK_INPUT(yLen);
  // batchOffset is only validated for packed input. The braces keep every check that
  // CHECK_INPUT performs under this condition, however many statements the macro expands to.
  if (packedInput) {
    CHECK_INPUT(batchOffset);
  }
  return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput);
}

// Checked entry point for the transducer loss backward pass.
//
// Verifies that x, label, lossGrad, alpha, beta, fLen and yLen are contiguous CUDA tensors,
// and that batchOffset is one as well when packedInput is true. All arguments are then
// passed through unchanged to transducer_loss_cuda_backward and its result is returned as-is.
// A failed check is reported through TORCH_CHECK.
torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta,
                                       torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label,
                                       torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt,
                                       bool fuseSoftmaxBackward, bool packedInput) {
  CHECK_INPUT(x);
  CHECK_INPUT(label);
  CHECK_INPUT(lossGrad);
  CHECK_INPUT(alpha);
  CHECK_INPUT(beta);
  CHECK_INPUT(fLen);
  CHECK_INPUT(yLen);
  // batchOffset is only validated for packed input. The braces keep every check that
  // CHECK_INPUT performs under this condition, however many statements the macro expands to.
  if (packedInput) {
    CHECK_INPUT(batchOffset);
  }

  return transducer_loss_cuda_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, maxFLen, blankIdx, opt,
                                       fuseSoftmaxBackward, packedInput);
}

// Python bindings for the extension module. Registers the checked wrappers under the names
// `forward` and `backward`; both use a call guard that releases the GIL while the bound
// function runs.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  using ReleaseGil = py::call_guard<py::gil_scoped_release>;

  m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", ReleaseGil());
  m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", ReleaseGil());
}
